Enabled Weighted Sampling #635
Conversation
kmontemayor2-sc
left a comment
Thanks Matt! Me and the robots did a first pass, it's possible they're imagining some of the issues here but I figured I'd flag :)
```python
    for edge_type, graph_partition_data in partitioned_edge_index.items()
    if graph_partition_data.weights is not None
}
edge_weights = weights_by_type if weights_by_type else None
```
does this var do anything? can we remove?
```python
"""Registers per-edge sampling weights to the partitioner.

Weights must be a 1-D float tensor of shape ``[num_edges]``, one scalar per edge.
Must be called after ``register_edge_index()`` and before ``partition()``.
```
Why is this the case? Where does this fail otherwise?
Is it just so we can check that they are the same shape later? Because we do `self._edge_index is not None`, which implies we'd skip this check.
More robot findings:
Lines 188–330 — `_partition_edge_index_and_edge_features` ignores `_edge_weights`
The override never reads `self._edge_weights`, never partitions it, and never sets `GraphPartitionData.weights`. Because `build_dataset_from_task_config_uri` defaults to range partitioning (dataset_factory.py:643-646), any user who calls `register_edge_weights` while using the default config-driven path silently loses all weights. Downstream `DistDataset._has_edge_weights` becomes `False`, and `validate_with_weight` fails with a confusing "no edge weights registered" error even though weights were registered.
Fix: Mirror `DistPartitioner._partition_edge_index_and_edge_features`. Roughly:
- After line 223:
```python
weight_tensor = None
if self._edge_weights is not None and edge_type in self._edge_weights:
    weight_tensor = self._edge_weights[edge_type]
```
- Augment `input_data` on lines 229/234 to include `weight_tensor` (e.g. `(..., edge_feat, weight_tensor)`)
- After `_partition_by_chunk` on line 246, slice partitioned weights back out of `res_list`
- Set `GraphPartitionData(..., weights=partitioned_weights)` in both branches at lines 308/320
- `del self._edge_weights[edge_type]` and clear when empty, mirroring the edge_feat cleanup at lines 253-262
Add a regression test for DistRangePartitioner + register_edge_weights.
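To make the alignment requirement concrete, here is a minimal, pure-Python sketch of the co-partitioning idea the fix describes (the function and field names here are illustrative, not GiGL's actual API): weights must be sliced with exactly the same edge index list as edge features, otherwise they are silently dropped or misaligned.

```python
# Illustrative sketch, NOT GiGL's API: co-partition edge weights with edge
# features so every row stays aligned with its edge after partitioning.

def partition_edges(edge_index, edge_feats, edge_weights, num_partitions):
    """Assign each edge to partition (dst % num_partitions), slicing features
    and weights with the same index list so rows stay row-aligned."""
    partitions = []
    for rank in range(num_partitions):
        idx = [
            i for i, (_, dst) in enumerate(edge_index)
            if dst % num_partitions == rank
        ]
        partitions.append({
            "edge_index": [edge_index[i] for i in idx],
            "feats": [edge_feats[i] for i in idx],
            # The flagged bug is the absence of this slice: if weights are not
            # carried through the same permutation, they are lost downstream.
            "weights": [edge_weights[i] for i in idx]
            if edge_weights is not None else None,
        })
    return partitions

edges = [(0, 1), (2, 3), (4, 1), (5, 2)]
feats = [[0.1], [0.2], [0.3], [0.4]]
weights = [1.0, 2.0, 3.0, 4.0]
parts = partition_edges(edges, feats, weights, num_partitions=2)
# Partition 0 gets edges with even dst (only (5, 2)); partition 1 gets the rest.
```

The same principle applies to the chunked partitioning path: the weight tensor rides along in `input_data` and is sliced back out of the per-chunk results.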
```python
    ],
    rank: int,
    tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
    weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
```
should we update the doc string for the new arg?
Also, it seems a little specific to put `weight_edge_feat_name` here when `_data_loading_process` is kind of a generic function? Or do you think that's fine for now? I'm not sure how else we'd address this.
More robot findings:
Lines 606–621 — Warning promises uniform fallback for partial heterogeneous weights, but GLT will crash
```python
weights_by_type = {
    edge_type: graph_partition_data.weights
    for edge_type, graph_partition_data in partitioned_edge_index.items()
    if graph_partition_data.weights is not None
}
edge_weights = weights_by_type if weights_by_type else None
if weights_by_type:
    missing = set(partitioned_edge_index.keys()) - set(weights_by_type.keys())
    if missing:
        logger.warning(
            f"... When with_weight=True, edge types without weights "
            f"will fall back to uniform sampling."
        )
```
GLT does not fall back to uniform sampling in this case. graphlearn_torch/sampler/neighbor_sampler.py:104-113 unconditionally instantiates `pywrap.CPUWeightedSampler` for every edge type when `with_weight=True`. For unweighted types, `Topology` keeps `edge_weights = torch.empty(0)`, `Graph::InitCPUGraphFromCSR` leaves `edge_weight_` unset, and `CPUWeightedSampler::Sample` dereferences a null prob pointer in `std::discrete_distribution<>` — undefined behavior, typically a crash.
Fix: Either tighten `BaseDistLoader.validate_with_weight` to require weights for every edge type when `with_weight=True`, or synthesize all-ones weight tensors here for missing edge types:
```python
if weights_by_type:
    for edge_type, gpd in partitioned_edge_index.items():
        if gpd.weights is None:
            weights_by_type[edge_type] = torch.ones(
                gpd.edge_index.shape[1], dtype=torch.float32
            )
    edge_weights = weights_by_type
```
Either way, rewrite the warning to reflect actual behavior. Add a test with one weighted + one unweighted sampled edge type and with_weight=True.
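The backfill option can be shown with a small self-contained sketch (plain Python with lists standing in for tensors; the dict layout and `num_edges` field are hypothetical, not GiGL's types). All-ones weights make the weighted sampler behave uniformly for that edge type, so the warning's promise becomes true by construction.

```python
# Sketch, assumed data layout: {edge_type: {"num_edges": int, "weights": list | None}}.
# Backfill uniform (all-ones) weights for edge types that registered none, so a
# weighted sampler never receives an empty weight tensor.

def backfill_uniform_weights(partitioned_edge_index):
    weights_by_type = {
        et: p["weights"]
        for et, p in partitioned_edge_index.items()
        if p["weights"] is not None
    }
    if not weights_by_type:
        # No weights registered at all: weighted sampling is unavailable.
        return None
    for et, p in partitioned_edge_index.items():
        if p["weights"] is None:
            # All-ones weights == uniform sampling for this edge type.
            weights_by_type[et] = [1.0] * p["num_edges"]
    return weights_by_type

parts = {
    ("u", "clicks", "i"): {"num_edges": 3, "weights": [0.5, 2.0, 1.5]},
    ("u", "views", "i"): {"num_edges": 2, "weights": None},
}
filled = backfill_uniform_weights(parts)
# filled[("u", "views", "i")] == [1.0, 1.0]
```

The alternative (failing fast in `validate_with_weight` when any sampled edge type lacks weights) trades convenience for an explicit error instead of silently mixing weighted and uniform behavior.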
```python
            f"weight_edge_feat_name '{col_name}' not found in edge feature keys "
            f"for edge type {edge_type}: {feature_keys}"
        )
    col_idx = feature_keys.index(col_name)
```
Robot finding - is this true? Do we support multi-dim features?
Lines 185, 207 — feature_keys.index(col_name) assumes each feature key is exactly one column wide
```python
col_idx = feature_keys.index(col_name)
weights[edge_type] = feat_tensor[:, col_idx]
keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx]
features[edge_type] = feat_tensor[:, keep_cols]
```
`_concatenate_features_by_names` (gigl/common/data/dataloaders.py:155-187) supports multi-dim features, so a single feature key may contribute multiple columns. Example: `feature_keys = ["embedding", "weight"]` where `embedding` has shape `[N, 16]` and `weight` has shape `[N, 1]` → the concatenated tensor has 17 columns and the weight lives at column 16, not column 1. tests/unit/common/data/dataloaders_test.py:494 only covers scalar features, so it doesn't catch this.
Fix: Compute the actual column offset by summing the widths of preceding features (from `feat_spec[feature_key].shape`, or by deriving the per-feature widths of the concatenated tensor), and assert the weight feature has width 1 before squeezing:
```python
col_widths = [
    serialized_tf_record_info[edge_type].feature_spec[k].shape[-1] or 1
    for k in feature_keys
]
col_offset = sum(col_widths[: feature_keys.index(col_name)])
weight_width = col_widths[feature_keys.index(col_name)]
assert weight_width == 1, f"weight column '{col_name}' must be width 1, got {weight_width}"
weights[edge_type] = feat_tensor[:, col_offset]
keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_offset]
features[edge_type] = feat_tensor[:, keep_cols]
```
Add a test with at least one multi-dim feature alongside the weight column.
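The off-by-many error is easy to demonstrate in isolation. This standalone snippet (plain Python, with assumed per-feature widths) contrasts the naive key-list index with the width-summed offset:

```python
# Why feature_keys.index() is the wrong column index for concatenated
# multi-dim features, and how summing preceding widths fixes it.

feature_keys = ["embedding", "weight"]
col_widths = {"embedding": 16, "weight": 1}  # assumed per-feature widths

# Naive: position in the key list -> column 1, which is *inside* "embedding".
naive_col = feature_keys.index("weight")

# Correct: offset = total width of all features that precede "weight".
widths = [col_widths[k] for k in feature_keys]
col_offset = sum(widths[: feature_keys.index("weight")])

# The weight column must be a single column before it can be squeezed out.
assert widths[feature_keys.index("weight")] == 1

# In the concatenated [N, 17] tensor the weight lives at column 16, not 1.
print(naive_col, col_offset)  # 1 16
```

With scalar-only features every width is 1 and the two computations coincide, which is exactly why the existing test doesn't catch the bug.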
/unit_test
GiGL Automation @ 23:03:34 UTC : 🔄 @ 24:11:45 UTC : ❌ Workflow failed.
GiGL Automation @ 23:03:35 UTC : 🔄 @ 23:05:34 UTC : ✅ Workflow completed successfully.
GiGL Automation @ 23:03:35 UTC : 🔄 @ 23:13:48 UTC : ✅ Workflow completed successfully.

/unit_test
GiGL Automation @ 06:09:08 UTC : 🔄 @ 06:19:33 UTC : ✅ Workflow completed successfully.
GiGL Automation @ 06:09:09 UTC : 🔄 @ 06:13:19 UTC : ✅ Workflow completed successfully.
GiGL Automation @ 06:09:09 UTC : 🔄 @ 07:13:24 UTC : ✅ Workflow completed successfully.
Summary
Adds native weighted edge sampling to GiGL's distributed training pipeline via GLT's CPUWeightedSampler. When enabled, neighbors are sampled proportionally to edge weights rather than uniformly.
New API
- `DistPartitioner.register_edge_weights(edge_weights)` — registers a 1-D per-edge weight tensor (homogeneous) or `dict[EdgeType, Tensor]` (heterogeneous) before calling `partition_edge_index_and_edge_features()`. Weights are partitioned alongside edge features in the same pass (co-partitioned, mirroring the node features + labels pattern).
- `load_torch_tensors_from_tf_record(weight_edge_feat_name=...)` — accepts the name of an existing edge feature column to extract as sampling weights during TFRecord loading. The column is sliced out of the feature tensor and stored in `LoadedGraphTensors.edge_weights`; it is never duplicated in memory.
- `build_dataset(weight_edge_feat_name=...)` — threads `weight_edge_feat_name` through to TFRecord loading and then calls `register_edge_weights()` with the extracted weights.
- `DistNeighborLoader(with_weight=True)` / `DistABLPLoader(with_weight=True)` — enables weighted sampling. Defaults to `False`; must be set explicitly.
- `BaseDistLoader.validate_with_weight()` — shared validation: raises `ValueError` if `with_weight=True` but no weights are registered in the dataset; raises `NotImplementedError` if used with `PPRSamplerOptions` (weight-proportional PPR residual propagation is deferred to a future PR).
Implementation notes
- `LoadedGraphTensors.edge_weights` — new field carrying extracted weights from TFRecord loading through to `register_edge_weights()`.
- `GraphPartitionData.weights` (field already existed) carries the partitioned weight tensor to `DistDataset._initialize_graph()`, which forwards it to GLT's `init_graph(edge_weights=...)`.
- `DistDataset.has_edge_weights` property reflects whether weights were registered at construction time.
- `SamplingConfig.with_weight` is now threaded through from the loader rather than hardcoded to `False`.
- `DistServer.get_edge_weights_registered()` and `RemoteDistDataset.fetch_edge_weights_registered()` propagate `has_edge_weights` across the RPC boundary so compute nodes can validate `with_weight` against the remote dataset.
Tests
- tests/unit/distributed/distributed_weighted_sampling_test.py (8 new tests): … `GraphPartitionData.edge_ids == FeaturePartitionData.ids`), and heterogeneous partial weights (one edge type weighted, another not).
- tests/unit/common/data/dataloaders_test.py (1 new test): `test_load_edge_weights_from_tf_record` — verifies that `load_torch_tensors_from_tf_record` correctly extracts a named column into `edge_weights`, removes it from `edge_features`, and returns the right shapes and values.